from pathlib import Path
import efel
import matplotlib.pyplot as plt
from datareuse import Reuse
import numpy as np
import sys
import pandas as pd
from bluepyparallel import evaluate
from matplotlib.backends.backend_pdf import PdfPages

import pyabf
import json

import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import shap

from xgboost import XGBClassifier
from sklearn.model_selection import RepeatedStratifiedKFold
from tqdm import tqdm


def _get_shap_feature_importance(shap_values):
    """From a list of shap values per folds, compute the global shap feature importance."""
    mean_shap_values = np.mean(shap_values, axis=0)
    if len(np.shape(mean_shap_values)) > 2:
        global_mean_shap_values = np.mean(mean_shap_values, axis=0)
        mean_shap_values = list(mean_shap_values)
    else:
        global_mean_shap_values = mean_shap_values

    shap_feature_importance = np.mean(abs(global_mean_shap_values), axis=0)
    return mean_shap_values, shap_feature_importance


def get_max_rb_df(df):
    """Keep, for every recorded cell, the single row with the most bursts.

    A cell is identified by its ("month", "day", "cell") triple. For each such
    group the row with the largest "all_burst_number" is kept; on ties the
    first such row wins (same rule as ``Series.idxmax``). Rows whose key
    contains a missing value are skipped (``groupby`` drops NaN keys).

    Args:
        df: DataFrame with at least the "month", "day", "cell" and
            "all_burst_number" columns.

    Returns:
        DataFrame with the same columns (and column dtypes) as ``df``: one row
        per cell, in order of first appearance, with a fresh 0..n-1 index.
    """
    # Work on positional labels so a non-unique input index cannot make one
    # label select several rows.
    positional = df.reset_index(drop=True)
    if positional.empty:
        return positional

    # One groupby replaces a full-frame boolean mask per cell plus a repeated
    # concat onto an empty frame. That was quadratic, and it upcast every
    # column to object dtype. sort=False keeps the first-appearance order
    # that drop_duplicates() gave.
    best_rows = positional.groupby(["month", "day", "cell"], sort=False)[
        "all_burst_number"
    ].idxmax()
    return positional.loc[best_rows.to_numpy()].reset_index(drop=True)


def train(X, y, n_splits=5, n_repeats=1):
    """Cross-validate an XGBoost classifier and collect fold-averaged SHAP values.

    A single ``XGBClassifier`` instance is refit on the training part of every
    fold of a ``RepeatedStratifiedKFold`` (fixed ``random_state=42``) and
    scored on the held-out part.

    Args:
        X: feature table, indexed positionally via ``.iloc``.
        y: labels aligned positionally with ``X``.
        n_splits: folds per repetition.
        n_repeats: number of repetitions of the stratified split.

    Returns:
        ``(model, acc_scores, mean_shap_values)``: the model as fitted on the
        last fold, the list of per-fold accuracies (``n_splits * n_repeats``
        entries), and the SHAP values averaged over folds.
    """
    model = XGBClassifier()
    splitter = RepeatedStratifiedKFold(
        n_splits=n_splits, n_repeats=n_repeats, random_state=42
    )
    n_folds = n_splits * n_repeats

    fold_accuracies = []
    fold_shap = []
    for fit_idx, holdout_idx in tqdm(splitter.split(X, y=y), total=n_folds):
        model.fit(X.iloc[fit_idx], y.iloc[fit_idx])
        predictions = model.predict(X.iloc[holdout_idx])
        fold_accuracies.append(accuracy_score(y.iloc[holdout_idx], predictions))

        # SHAP is evaluated on the full X, not only the held-out rows. This
        # gives every fold an array of the same shape, which the fold-wise
        # mean in _get_shap_feature_importance requires.
        fold_shap.append(shap.TreeExplainer(model).shap_values(X))

    mean_shap_values, _ = _get_shap_feature_importance(fold_shap)
    return model, fold_accuracies, mean_shap_values


def classify_and_explain(data_df, columns_to_drop=None, splits=10, repeats=10):
    """Cross-validate a classifier on ``data_df`` and draw a SHAP bar summary.

    The class label of each row is read from ``data_df``'s index; every column
    not listed in ``columns_to_drop`` is used as a feature. Mean and standard
    deviation of the fold accuracies are printed, and a SHAP bar plot is drawn
    with ``show=False``, so the caller is responsible for saving/closing it.

    Args:
        data_df: feature table whose index holds the class labels.
        columns_to_drop: optional list of feature columns to exclude.
        splits: folds per repetition of the stratified K-fold.
        repeats: number of repetitions.

    Returns:
        ``(accuracy, error)``: mean and standard deviation of the per-fold
        accuracies, in percent, rounded to two decimals.
    """
    # Labels live in the index. Copy them into a 1-D Series with a fresh
    # RangeIndex so they line up positionally with X below.
    y = pd.Series(data_df.index.to_numpy(), name=data_df.index.name)
    X = data_df.reset_index(drop=True)
    if columns_to_drop is not None:
        X = X.drop(columns=columns_to_drop)

    _, acc_scores, shap_values = train(X, y, n_splits=splits, n_repeats=repeats)
    accuracy = np.round(np.mean(acc_scores) * 100, 2)
    error = np.round(np.std(acc_scores) * 100, 2)

    print(f"mean accuracy is {accuracy}")
    print(f"mean accuracy stdev is  {error}")

    shap.summary_plot(shap_values, X, plot_type="bar", show=False, plot_size=(5, 3))
    return accuracy, error


if __name__ == "__main__":
    data_df = pd.read_csv("../figure_1/feature_data_df_clean.csv")
    data_df = data_df[data_df["tonic_after_burst"] == 0]

    # find max rb trace per cell and classify with that
    data_df = get_max_rb_df(data_df)
    data_df = data_df[data_df["postburst_min_values"] < 0]

    # Binary target: "ecel" -> 0, "spp" and "runaway" pooled -> 1. Any other
    # cell_type would map to NaN and break training, so fail early instead.
    labels = data_df["cell_type"].map({"ecel": 0, "spp": 1, "runaway": 1})
    unknown_types = data_df.loc[labels.isna(), "cell_type"].unique()
    if len(unknown_types):
        raise ValueError(f"Unmapped cell_type values: {list(unknown_types)}")
    data_df.index = labels

    cols = [
        "all_burst_number",
        "runaway",
        "spike_width2",
        "peak_voltage",
        "spikes_per_burst",
        "burst_mean_freq",
        "inv_first_ISI",
        "AP2_AP1_peak_diff",
        "AHP_depth_abs",
        "time_to_first_spike",
        "postburst_min_values",
    ]
    data_df = data_df[cols]

    top_three = ["all_burst_number", "burst_mean_freq", "spike_width2"]
    # (output pdf, columns to drop) for each classification run.
    experiments = [
        # using all features
        ("all_features.pdf", None),
        # using all features but without burst number
        ("no_burst_number.pdf", ["all_burst_number"]),
        # using top 3 features
        ("top_three.pdf", [col for col in cols if col not in top_three]),
    ]
    for pdf_name, columns_to_drop in experiments:
        classify_and_explain(data_df, columns_to_drop)
        plt.savefig(pdf_name)
        # Close after every run, including the last, so figures never pile up.
        plt.close()
